import argparse
import os
import time

import numpy as np
import optuna
import torch
from torch.nn.functional import l1_loss, mse_loss
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.utils.data import TensorDataset, DataLoader, random_split

from Dataset import load
from reform.QM9NeiborEmbDirSchNet import DirSchNet
from reform import Utils

# Small positive constant; not referenced in the visible part of this file.
EPS = 1e-6
# Not referenced in the visible part of this file; purpose unknown — TODO confirm.
step_plot = -1

# Module-level default; work() overwrites it when called with a "ratio_y" keyword.
ratio_y = 0.01
# Not referenced in the visible part of this file.
ratio_dy = 1


def buildModel(**kwargs):
    """Create a DirSchNet configured for the dataset chosen on the command line.

    The ``tar`` argument handed to DirSchNet is the dataset name itself for
    "dipole_moment" and "electronic_spatial_extent", and "scalar" for every
    other dataset. The normalisation statistics are read from the module
    globals ``y_mean``, ``y_std`` and ``global_y_mean``.

    Args:
        **kwargs: forwarded unchanged to the DirSchNet constructor.

    Returns:
        The freshly constructed model (not yet moved to any device).
    """
    special_targets = ("dipole_moment", "electronic_spatial_extent")
    target_kind = args.dataset if args.dataset in special_targets else "scalar"

    mod = DirSchNet(y_mean=y_mean,
                    y_std=y_std,
                    global_y_mean=global_y_mean,
                    tar=target_kind,
                    **kwargs)

    # Report the number of trainable parameters.
    n_trainable = sum(p.numel() for p in mod.parameters() if p.requires_grad)
    print(f"numel {n_trainable}")
    return mod


# ---- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(description='')
# NOTE(review): the default "benzene" is not in the supported list below, so
# running without --dataset always ends in NotImplementedError.
parser.add_argument('--dataset', type=str, default="benzene")
parser.add_argument('--modname', type=str, default="0")
parser.add_argument('--test', action="store_true")
parser.add_argument('--test_mod', action="store_true")
args = parser.parse_args()

# Default checkpoint path; the --test branch rebinds it per seed.
modfilename = f"save_mod/{args.dataset}.dirschnet.pt"

# Use the GPU when there is one, otherwise fall back to the CPU instead of
# failing on the first tensor transfer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = load(args.dataset)
_supported_datasets = (
    "dipole_moment", "isotropic_polarizability", "homo", "lumo", "gap",
    "electronic_spatial_extent", "zpve", "energy_U0", "energy_U",
    "enthalpy_H", "free_energy", "heat_capacity"
)
if args.dataset in _supported_datasets:
    # Train / validation / test sizes handed to random_split() in work().
    ratio = [110000, 10000, 10831]
else:
    raise NotImplementedError(
        f"dataset {args.dataset!r} is not supported; choose one of: "
        f"{', '.join(_supported_datasets)}")

# N is the size of the second dimension of "z"; the tensors are flattened to
# (-1, N), (-1, N, 3) and (-1, 1) respectively.
N = dataset["z"].shape[1]
ds = TensorDataset(dataset['z'].reshape(-1,
                                        N), dataset['pos'].reshape(-1, N, 3),
                   dataset['y'].reshape(-1, 1))

# Target statistics read by buildModel(); work() fills in y_mean and y_std
# from the training split before it builds the model.
y_mean = None
y_std = None
global_y_mean = 0.0


def work(total_step: int = 4000,
         batch_size: int = 256,
         save_model: bool = False,
         do_test: bool = False,
         jump_train: bool = False,
         max_early_stop: int = 100,
         lr: float = 1e-3,
         warmup: int = 3,
         patience: int = 10,
         **kwargs):
    """Train a model on a fresh random split and return its best validation loss.

    The module-level dataset ``ds`` is split with ``random_split(ds, ratio)``.
    The mean and standard deviation of the training targets are stored in the
    module globals ``y_mean`` / ``y_std``, which ``buildModel`` reads.

    Unless ``jump_train`` is set, the model is optimised with Adam. The
    learning rate starts at ``lr / 100`` and, during the first ``warmup``
    epochs, is multiplied after every batch by a constant factor chosen so
    that it has grown 100-fold after ``warmup * (train_size // batch_size)``
    steps. Afterwards ``ReduceLROnPlateau`` (factor 0.6) tracks the
    validation loss.

    Args:
        total_step: maximum number of training epochs.
        batch_size: training batch size; validation and test loaders use
            four times this value.
        save_model: write the state dict to ``modfilename`` whenever the
            validation loss improves.
        do_test: after training, reload ``modfilename`` and print the
            train / validation / test scores.
        jump_train: skip the training loop entirely.
        max_early_stop: stop once this many consecutive epochs have passed
            without a validation improvement.
        lr: learning rate reached at the end of warm-up.
        warmup: number of warm-up epochs.
        patience: ``patience`` of the ReduceLROnPlateau scheduler.
        **kwargs: forwarded to ``buildModel``; a ``"ratio_y"`` entry also
            overwrites the module global ``ratio_y``.

    Returns:
        The lowest validation loss observed, capped at 1e3. 1e3 is also
        returned as soon as a training or validation loss is NaN, and when
        no training epoch was run.
    """
    global y_mean, y_std, ratio_y
    if "ratio_y" in kwargs:
        ratio_y = kwargs["ratio_y"]

    nan_penalty = 1e3
    trn_ds, val_ds, tst_ds = random_split(ds, ratio)

    # Index the training split once and reuse the result for both the
    # statistics and the loader, rather than indexing it three times.
    trn_tensors = trn_ds.dataset[trn_ds.indices]
    y_mean = torch.mean(trn_tensors[2]).item()
    y_std = torch.std(trn_tensors[2]).item()

    trn_dl = Utils.tensorDataloader(trn_tensors, batch_size, True, device,
                                    True)
    val_dl = Utils.tensorDataloader(val_ds.dataset[val_ds.indices],
                                    4 * batch_size, False, device, False)
    tst_dl = Utils.tensorDataloader(tst_ds.dataset[tst_ds.indices],
                                    4 * batch_size, False, device, False)

    mod = buildModel(**kwargs).to(device)
    best_val_loss = float("inf")
    if not jump_train:
        # Derived from the actual split size instead of a hard-coded 110000.
        # Assumes the loader yields this many batches per epoch — TODO confirm.
        steps_per_epoch = len(trn_ds) // batch_size
        opt = Adam(mod.parameters(), lr=lr / 100)
        # gamma ** (warmup * steps_per_epoch) == 100, i.e. lr/100 -> lr.
        scd1 = StepLR(opt,
                      1,
                      gamma=100**(1 / (warmup * steps_per_epoch)))
        scd2 = ReduceLROnPlateau(opt,
                                 mode="min",
                                 factor=0.6,
                                 patience=patience,
                                 min_lr=1e-6)  #5e-5)
        early_stop = 0
        for epoch in range(total_step):
            curlr = opt.param_groups[0]["lr"]
            trn_losss = []
            for batch in trn_dl:
                trn_loss_y = Utils.train(batch, opt, mod, mse_loss)
                if np.isnan(trn_loss_y):
                    return nan_penalty
                if epoch < warmup:
                    scd1.step()
                trn_losss.append(trn_loss_y)
            trn_loss_y = np.average(trn_losss)

            val_loss = Utils.testdl(val_dl, mod, l1_loss)
            # Bail out before a NaN is fed to the plateau scheduler.
            if np.isnan(val_loss):
                return nan_penalty
            # NOTE(review): at epoch == warmup neither scheduler steps;
            # confirm that ">" rather than ">=" is intended.
            if epoch > warmup:
                scd2.step(val_loss)
            early_stop += 1
            if val_loss < best_val_loss:
                early_stop = 0
                best_val_loss = val_loss
                if save_model:
                    torch.save(mod.state_dict(), modfilename)
                tst_score = Utils.testdl(tst_dl, mod, l1_loss)
                print(f"tst E {tst_score:.4f}")
            if early_stop > max_early_stop:
                break
            print(
                f"iter {epoch} lr {curlr:.4e} trn E {trn_loss_y:.4f} val E {val_loss:.4f}",
                flush=True)

    if do_test:
        # Evaluate the checkpoint stored at modfilename, not the in-memory
        # weights from the final epoch.
        mod.load_state_dict(torch.load(modfilename, map_location="cpu"))
        mod = mod.to(device)
        tst_score = Utils.testdl(tst_dl, mod, l1_loss)
        trn_score = Utils.testdl(trn_dl, mod, l1_loss)
        val_score = Utils.testdl(val_dl, mod, l1_loss)
        print(trn_score, val_score, tst_score)
    return min(best_val_loss, nan_penalty)


# Hyperparameters held constant during the Optuna search; search() forwards
# them to work() together with the values sampled for each trial.
fixed_p = dict(
    batch_size=64,
    hid_dim=256,
    lr=0.0003,
    patience=20,
    rbf="nexpnorm",
    rbound_lower=0.0,
    warmup=20,
    ln_lin1=True,
    ln_s2v=True,
    max_z=20,
    wd=0,
    lin1_tailact=True,
)


def search(trial: optuna.Trial):
    """Optuna objective: sample one configuration, train it, return its score.

    The sampled values are merged with ``fixed_p`` and passed to ``work``
    with a budget of 500 epochs and an early-stopping window of 50.

    Args:
        trial: the Optuna trial used to draw the hyperparameters.

    Returns:
        The value returned by ``work`` for this configuration (minimised by
        the study).
    """
    # The dict literal is evaluated top to bottom, so parameters are drawn
    # in the order listed here.
    sampled = {
        "ln_emb":
        trial.suggest_categorical("ln_emb", [True, False]),
        "num_mplayer":
        trial.suggest_int("num_mplayer", 4, 8, step=2),
        # A 0.5 step needs a float distribution; suggest_int only accepts
        # integer steps.
        # NOTE(review): a study database that already stores "cutoff" with an
        # integer distribution may reject this — confirm before resuming one.
        "cutoff":
        trial.suggest_float("cutoff", 4.0, 10.0, step=0.5),
        "ef_dim":
        trial.suggest_int("ef_dim", 32, 128, step=16),
        "add_ef2dir":
        trial.suggest_categorical("add_ef2dir", [True, False]),
        "ef_decay":
        trial.suggest_categorical("ef_decay", [True, False]),
        "ev_decay":
        trial.suggest_categorical("ev_decay", [True, False]),
        "dir2mask_tailact":
        trial.suggest_categorical("dir2mask_tailact", [True, False]),
        "ef2mask_tailact":
        trial.suggest_categorical("ef2mask_tailact", [True, False]),
    }
    ret = work(**fixed_p,
               **sampled,
               total_step=500,
               max_early_stop=50,
               search_hp=True)
    # Blank line separates the logs of consecutive trials.
    print("", flush=True)
    return ret


# SQLite does not create missing parent directories for its database file,
# so make sure the storage directory exists before opening the study.
os.makedirs("Opt", exist_ok=True)
study = optuna.create_study(direction="minimize",
                            storage="sqlite:///" + "Opt/" + args.dataset +
                            "lr.db",
                            study_name=args.dataset,
                            load_if_exists=True)

if args.test:
    # Final evaluation: train with the stored parameters for three seeds.
    import qm9_params
    tp = qm9_params.param
    # work() saves checkpoints under save_mod/; create the directory up front
    # so a missing folder does not abort the run at the first save.
    os.makedirs("save_mod", exist_ok=True)
    t1 = time.time()
    for i in range(3):
        print(f"seed {i}")
        Utils.set_seed(i)
        # work() reads this module-level name when saving / loading weights.
        modfilename = f"save_mod/{args.dataset}.dirschnet.{args.modname}.{i}.pt"
        print(
            work(**tp,
                 total_step=1000,
                 max_early_stop=100,
                 save_model=True,
                 do_test=True))
    print(f"time {time.time()-t1:.2f} s")
else:
    # Hyperparameter search; the reported best value is the lowest
    # validation L1 loss returned by work(), despite the "valf1" label.
    study.optimize(search, n_trials=200)
    print("best params ", study.best_params)
    print("best valf1 ", study.best_value)
